from typing import *
import pandas as pd
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import plotly
from pykalman import KalmanFilter
class KalmanFilter1D:
    """
    Simple Kalman filter implementation for a 1D (scalar) signal
    that doesn't take into account the control-input model (B).

    Every model and covariance is a scalar, so the usual matrix products
    and the matrix inverse reduce to plain multiplication and division.

    NOTE: ``filter_update`` always runs a predict step before the update, so
    the initial state/covariance are propagated forward once (and the process
    noise is added) before the first observation is incorporated.
    NOTE(review): pykalman's ``KalmanFilter.filter`` appears to use its initial
    state mean/covariance directly as the prior for the first observation (no
    predict step), so the first outputs of the two implementations may differ
    slightly for identical parameters - confirm if an exact match is needed.

    The filter is stateful: ``filter`` continues from wherever the previous
    ``filter``/``filter_update`` call left off. Create a new instance to
    start from the initial state again.
    """

    def __init__(self, init_state, init_state_covariance,
                 state_transition_model=1, observation_model=1,
                 process_noise_covariance=1, observation_noise_covariance=1):
        """
        :param init_state: initial state estimate.
        :param init_state_covariance: variance of the initial state estimate.
        :param state_transition_model: F; the state is multiplied by it on
            every predict step.
        :param observation_model: H; maps the state into observation space.
        :param process_noise_covariance: Q; added to the estimate variance on
            every predict step.
        :param observation_noise_covariance: R; variance of the observations.
        """
        self.transition_model = state_transition_model
        self.observation_model = observation_model
        self.process_cov = process_noise_covariance
        self.observation_cov = observation_noise_covariance
        self.state_estimate = init_state
        self.estimate_cov = init_state_covariance
        # A-priori (predicted) quantities, filled in by ``_predict``. They are
        # assigned here rather than only annotated, so the attributes exist
        # (as None) even before the first update.
        self.predicted_state: Optional[float] = None
        self.predicted_estimate_cov: Optional[float] = None

    def _predict(self):
        """Time update: project the current estimate one step forward."""
        self.predicted_state = self.transition_model * self.state_estimate
        # P_pred = F * P * F + Q
        self.predicted_estimate_cov = self.transition_model * self.estimate_cov \
            * self.transition_model + self.process_cov

    def _update(self, observation: float):
        """Measurement update: blend ``observation`` into the prediction.

        Must be called after ``_predict`` (it reads the predicted values).
        """
        # Residual between the observation and the predicted observation.
        innovation = observation - self.observation_model * self.predicted_state
        # S = H * P_pred * H + R
        innovation_cov = self.observation_model * self.predicted_estimate_cov \
            * self.observation_model + self.observation_cov
        # K = P_pred * H / S (scalar form of P H^T S^-1)
        kalman_gain = self.predicted_estimate_cov * self.observation_model \
            / innovation_cov
        self.state_estimate = self.predicted_state + kalman_gain * innovation
        self.estimate_cov = (1 - kalman_gain * self.observation_model) * \
            self.predicted_estimate_cov

    def filter_update(self, observation: float) -> float:
        """Process one observation (predict, then update).

        Returns the new a-posteriori state estimate.
        """
        self._predict()
        self._update(observation)
        return self.state_estimate

    def filter(self, series: Iterable[float]) -> List[float]:
        """Run ``filter_update`` over every observation, in order.

        :param series: observations in time order. Any iterable works
            (list, tuple, NumPy array, generator).
        :returns: list of the filtered state estimates, one per observation.
        """
        return [self.filter_update(observation) for observation in series]
# Location of the price history and the columns to use from it.
data_dir = './data'  # plain literal: nothing to interpolate, so no f-prefix
file_path = f'{data_dir}/NOK.csv'
price_col = 'Adj Close'
date_col = 'Date'

data = pd.read_csv(file_path)
series = data[price_col].to_numpy()

# Model parameters shared by BOTH filters. They are defined once so the
# hand-written filter and pykalman are guaranteed to be configured
# identically, which is what makes the comparison plot meaningful.
init_state_cov = 1
transition_model = 1
observation_model = 1
process_noise_cov = 0.05
observation_noise_cov = 3

# Hand-written scalar Kalman filter.
kf = KalmanFilter1D(
    init_state=series[0],
    init_state_covariance=init_state_cov,
    state_transition_model=transition_model,
    process_noise_covariance=process_noise_cov,
    observation_model=observation_model,
    observation_noise_covariance=observation_noise_cov,
)
filtered1 = kf.filter(series)
data['my_kalman_filter'] = filtered1

# Reference implementation from pykalman, with the same parameters.
kf2 = KalmanFilter(
    initial_state_mean=series[0],
    initial_state_covariance=init_state_cov,
    transition_matrices=transition_model,
    transition_covariance=process_noise_cov,
    observation_matrices=observation_model,
    observation_covariance=observation_noise_cov,
)
# Per the pykalman docs, filter() returns (filtered_state_means,
# filtered_state_covariances), and the means are 2-D with shape
# (n_timesteps, n_dim_state) = (n, 1) here. Take the single state dimension
# explicitly instead of relying on pandas to accept a 2-D array as a column.
filtered2 = kf2.filter(series)[0]
data['kalman_lib_filter'] = filtered2[:, 0]
# Overlay the raw prices and both filtered series on a single chart.
# Each pair is (legend name, DataFrame column); the order sets the trace order.
fig = go.Figure(
    data=[
        go.Scatter(x=data[date_col], y=data[column], name=trace_name)
        for trace_name, column in (
            ('input_data', price_col),
            ('kalman_lib_filter', 'kalman_lib_filter'),
            ('my_kalman_filter', 'my_kalman_filter'),
        )
    ],
    layout=go.Layout(
        xaxis_title=date_col,
        yaxis_title=price_col,
        height=500,
        width=1000,
    ),
)
fig.show("notebook")